{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import math\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x): \n",
    "    return 1. / (1 + np.exp(-x))\n",
    "\n",
    "# createst uniform random array w/ values in [a,b) and shape args\n",
    "def rand_arr(a, b, *args): \n",
    "    np.random.seed(0)\n",
    "    return np.random.rand(*args) * (b - a) + a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LstmParam:\n",
    "    def __init__(self, mem_cell_ct, x_dim):\n",
    "        self.mem_cell_ct = mem_cell_ct #lstm神经元的数目\n",
    "        self.x_dim = x_dim # 输入数据的维度\n",
    "        concat_len = x_dim + mem_cell_ct \n",
    "        # weight matrices\n",
    "        self.wg = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)\n",
    "        self.wi = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len) \n",
    "        self.wf = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)\n",
    "        self.wo = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)\n",
    "        # bias terms\n",
    "        self.bg = rand_arr(-0.1, 0.1, mem_cell_ct) \n",
    "        self.bi = rand_arr(-0.1, 0.1, mem_cell_ct) \n",
    "        self.bf = rand_arr(-0.1, 0.1, mem_cell_ct) \n",
    "        self.bo = rand_arr(-0.1, 0.1, mem_cell_ct) \n",
    "        # diffs (derivative of loss function w.r.t. all parameters)\n",
    "        self.wg_diff = np.zeros((mem_cell_ct, concat_len)) \n",
    "        self.wi_diff = np.zeros((mem_cell_ct, concat_len)) \n",
    "        self.wf_diff = np.zeros((mem_cell_ct, concat_len)) \n",
    "        self.wo_diff = np.zeros((mem_cell_ct, concat_len)) \n",
    "        self.bg_diff = np.zeros(mem_cell_ct) \n",
    "        self.bi_diff = np.zeros(mem_cell_ct) \n",
    "        self.bf_diff = np.zeros(mem_cell_ct) \n",
    "        self.bo_diff = np.zeros(mem_cell_ct) \n",
    "\n",
    "    def apply_diff(self, lr = 1):\n",
    "        self.wg -= lr * self.wg_diff\n",
    "        self.wi -= lr * self.wi_diff\n",
    "        self.wf -= lr * self.wf_diff\n",
    "        self.wo -= lr * self.wo_diff\n",
    "        self.bg -= lr * self.bg_diff\n",
    "        self.bi -= lr * self.bi_diff\n",
    "        self.bf -= lr * self.bf_diff\n",
    "        self.bo -= lr * self.bo_diff\n",
    "        # reset diffs to zero\n",
    "        self.wg_diff = np.zeros_like(self.wg)\n",
    "        self.wi_diff = np.zeros_like(self.wi) \n",
    "        self.wf_diff = np.zeros_like(self.wf) \n",
    "        self.wo_diff = np.zeros_like(self.wo) \n",
    "        self.bg_diff = np.zeros_like(self.bg)\n",
    "        self.bi_diff = np.zeros_like(self.bi) \n",
    "        self.bf_diff = np.zeros_like(self.bf) \n",
    "        self.bo_diff = np.zeros_like(self.bo) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LstmState:\n",
    "    def __init__(self, mem_cell_ct, x_dim):\n",
    "        self.g = np.zeros(mem_cell_ct)\n",
    "        self.i = np.zeros(mem_cell_ct)\n",
    "        self.f = np.zeros(mem_cell_ct)\n",
    "        self.o = np.zeros(mem_cell_ct)\n",
    "        self.s = np.zeros(mem_cell_ct)\n",
    "        self.h = np.zeros(mem_cell_ct)\n",
    "        self.bottom_diff_h = np.zeros_like(self.h)\n",
    "        self.bottom_diff_s = np.zeros_like(self.s)\n",
    "        self.bottom_diff_x = np.zeros(x_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "    \n",
    "class LstmNode:\n",
    "    def __init__(self, lstm_param, lstm_state):\n",
    "        # store reference to parameters and to activations\n",
    "        self.state = lstm_state\n",
    "        self.param = lstm_param\n",
    "        # non-recurrent input to node\n",
    "        self.x = None\n",
    "        # non-recurrent input concatenated with recurrent input\n",
    "        self.xc = None\n",
    "\n",
    "    def bottom_data_is(self, x, s_prev = None, h_prev = None, flag = 0):\n",
    "        \n",
    "        # if this is the first lstm node in the network\n",
    "        if flag == 0:\n",
    "            s_prev = np.zeros_like(self.state.s)\n",
    "            h_prev = np.zeros_like(self.state.h)\n",
    "        # save data for use in backprop\n",
    "        self.s_prev = s_prev\n",
    "        self.h_prev = h_prev\n",
    "\n",
    "        # concatenate x(t) and h(t-1)\n",
    "        xc = np.hstack((x,  h_prev))\n",
    "        self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg)\n",
    "        self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi)\n",
    "        self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf)\n",
    "        self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo)\n",
    "        self.state.s = self.state.g * self.state.i + s_prev * self.state.f\n",
    "        self.state.h = self.state.s * self.state.o\n",
    "        self.x = x\n",
    "        self.xc = xc\n",
    "    \n",
    "    def top_diff_is(self, top_diff_h, top_diff_s):\n",
    "        # top_diff_h : dL(t)/dh(t) = dl(t)/dh(t) + dL(t+1)/dh(t)\n",
    "        # top_diff_s : dL(t+1)/ds(t)\n",
    "        # notice that top_diff_s is carried along the constant error carousel（ 常量错误木马）\n",
    "        ds = self.state.o * top_diff_h + top_diff_s\n",
    "        do = self.state.s * top_diff_h\n",
    "        di = self.state.g * ds\n",
    "        dg = self.state.i * ds\n",
    "        df = self.s_prev * ds\n",
    "\n",
    "        # diffs w.r.t. vector inside sigma / tanh function\n",
    "        di_input = (1. - self.state.i) * self.state.i * di \n",
    "        df_input = (1. - self.state.f) * self.state.f * df \n",
    "        do_input = (1. - self.state.o) * self.state.o * do \n",
    "        dg_input = (1. - self.state.g ** 2) * dg\n",
    "\n",
    "        # diffs w.r.t. inputs\n",
    "        self.param.wi_diff += np.outer(di_input, self.xc)\n",
    "        self.param.wf_diff += np.outer(df_input, self.xc)\n",
    "        self.param.wo_diff += np.outer(do_input, self.xc)\n",
    "        self.param.wg_diff += np.outer(dg_input, self.xc)\n",
    "        self.param.bi_diff += di_input\n",
    "        self.param.bf_diff += df_input       \n",
    "        self.param.bo_diff += do_input\n",
    "        self.param.bg_diff += dg_input       \n",
    "\n",
    "        # compute bottom diff\n",
    "        dxc = np.zeros_like(self.xc)\n",
    "        dxc += np.dot(self.param.wi.T, di_input)\n",
    "        dxc += np.dot(self.param.wf.T, df_input)\n",
    "        dxc += np.dot(self.param.wo.T, do_input)\n",
    "        dxc += np.dot(self.param.wg.T, dg_input)\n",
    "\n",
    "        # save bottom diffs\n",
    "        self.state.bottom_diff_s = ds * self.state.f\n",
    "        self.state.bottom_diff_x = dxc[:self.param.x_dim]\n",
    "        self.state.bottom_diff_h = dxc[self.param.x_dim:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LstmNetwork():\n",
    "    def __init__(self, lstm_param):\n",
    "        self.lstm_param = lstm_param\n",
    "        self.lstm_node_list = []\n",
    "        # input sequence\n",
    "        self.x_list = []\n",
    "\n",
    "    def y_list_is(self, y_list, loss_layer):\n",
    "        \"\"\"\n",
    "        Updates diffs by setting target sequence \n",
    "        with corresponding loss layer. \n",
    "        Will *NOT* update parameters.  To update parameters,\n",
    "        call self.lstm_param.apply_diff()\n",
    "        \"\"\" \n",
    "        assert len(y_list) == len(self.x_list)\n",
    "        idx = len(self.x_list) - 1\n",
    "        # first node only gets diffs from label ...\n",
    "        loss = loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])\n",
    "        diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])\n",
    "        # here s is not affecting loss due to h(t+1) ??????, hence we set equal to zero\n",
    "        diff_s = np.zeros(self.lstm_param.mem_cell_ct)\n",
    "        self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)\n",
    "        idx -= 1\n",
    "\n",
    "        ### ... following nodes also get diffs from next nodes, hence we add diffs to diff_h\n",
    "        ### we also propagate error along constant error carousel using diff_s\n",
    "        while idx >= 0:\n",
    "            loss += loss_layer.loss(self.lstm_node_list[idx].state.h, y_list[idx])\n",
    "            diff_h = loss_layer.bottom_diff(self.lstm_node_list[idx].state.h, y_list[idx])\n",
    "            diff_h += self.lstm_node_list[idx + 1].state.bottom_diff_h # diff_h是一个递归的\n",
    "            diff_s = self.lstm_node_list[idx + 1].state.bottom_diff_s\n",
    "            self.lstm_node_list[idx].top_diff_is(diff_h, diff_s)\n",
    "            idx -= 1 \n",
    "\n",
    "        return loss\n",
    "\n",
    "    def x_list_clear(self):\n",
    "        self.x_list = []\n",
    "\n",
    "    def x_list_add(self, x):\n",
    "        self.x_list.append(x)\n",
    "        if len(self.x_list) > len(self.lstm_node_list):\n",
    "            # need to add new lstm node, create new state mem\n",
    "            lstm_state = LstmState(self.lstm_param.mem_cell_ct, self.lstm_param.x_dim)\n",
    "            self.lstm_node_list.append(LstmNode(self.lstm_param, lstm_state))\n",
    "\n",
    "        # get index of most recent x input\n",
    "        idx = len(self.x_list) - 1\n",
    "        if idx == 0:\n",
    "            # no recurrent inputs yet\n",
    "            self.lstm_node_list[idx].bottom_data_is(x,flag = idx)\n",
    "        else:\n",
    "            s_prev = self.lstm_node_list[idx - 1].state.s\n",
    "            h_prev = self.lstm_node_list[idx - 1].state.h\n",
    "            self.lstm_node_list[idx].bottom_data_is(x, s_prev, h_prev,flag = idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyLossLayer:\n",
    "    \"\"\"\n",
    "    Computes square loss with first element of hidden layer array.\n",
    "    \"\"\"\n",
    "    @classmethod\n",
    "    def loss(self, pred, label):\n",
    "        return (pred[0] - label) ** 2\n",
    "\n",
    "    @classmethod\n",
    "    def bottom_diff(self, pred, label):\n",
    "        diff = np.zeros_like(pred)\n",
    "        diff[0] = 2 * (pred[0] - label)\n",
    "        return diff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f36f5a62be0>]"
      ]
     },
     "execution_count": 148,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcYklEQVR4nO3dfXBc13nf8e/DxYJcQJRAipBtgpRJ2zQVynphitKJpTa21ISUnJq04kyoNG3SuMNha+XFbdiAdcczmdQjeZg2dmopHI6i2n0Z0x6bgZmINpxYSZxatktIZCTREhRKckyAsgWLgkiKAPH29I+7Cy6Xu9gLYBe799zfZ4aj3bsXu+cI4I8Hzzn3XHN3REQk+ZY0ugEiIlIbCnQRkUAo0EVEAqFAFxEJhAJdRCQQLY364FWrVvm6desa9fEiIon0xBNP/NjdO8u91rBAX7duHf39/Y36eBGRRDKzf6j0mkouIiKBUKCLiARCgS4iEggFuohIIBToIiKBaNgql/noPTbEvr4BTo+Msrojx56tG9mxuavRzRIRaQqJCfTeY0PsPfQ0oxNTAAyNjLL30NMACnURERJUctnXNzAT5gWjE1Ps6xtoUItERJpLYgL99MjonI6LiKRNYgJ9dUduTsdFRNImMYG+Z+tGctnMZcdy2Qx7tm5sUItERJpLYiZFCxOfD3z1OX54doxrcll+7wM3akJURCQvMSN0iEL9b3/3fQD8+m3rFeYiIkUSFegA2cwSli9r4bUL441uiohIU0lcoAOsaGtVoIuIlEhmoLe38tqFiUY3Q0SkqSQz0NuyvPaGRugiIsUSGegrVXIREblCIgO9o61VI3QRkRKxAt3MtpnZgJmdNLOeMq/vMbPj+T/PmNmUma2sfXMjK9qyvDE+xcXJqeoni4ikRNVAN7MM8CBwF7AJuNfMNhWf4+773P1Wd78V2Av8jbufqUN7gWhSFGBEE6MiIjPijNC3ACfd/UV3HwcOAttnOf9e4PO1aFwlK9qiQFcdXUTkkjiB3gWcKno+mD92BTNrA7YBX67w+i4z6zez/uHh4bm2dcaK9iwAZ1RHFxGZESfQrcwxr3DuPwe+Vanc4u4H3L3b3bs7OzvjtvEKhRG6Si4iIpfECfRBYG3R8zXA6Qrn7qTO5RaAle0quYiIlIoT6EeBDWa23sxaiUL7cOlJZnYN8DPAV2rbxCt1tEUlFy1dFBG5pOr2ue4+aWb3AX1ABnjE3U+Y2e786/vzp34Q+Lq7v1G31uYtbcnQ3prR5f8iIkVi7Yfu7keAIyXH9pc8/yzw2Vo1rBpdXCQicrlEXikK0UoX1dBFRC5JbqC3tXJGJRcRkRmJDvQRjdBFRGYkNtBXtquGLiJSLLGB3tGW5ezYJJNT041uiohIU0hsoBcuLhoZVR1dRAQSHOgdhQ26VHYREQESHOgrCleLaqWLiAiQ6ECPRujacVFEJJLcQJ+5yYUCXUQEEhzoKwsjdAW6iAiQ4EDPtWZY2rJEe6KLiOQlNtBBFxeJiBRLdKB3tLVqgy4RkbxEB/rK9qyWLYqI5CU60LUnuojIJYkO9BVt2hNdRKQgVqCb2TYzGzCzk2bWU+Gc95rZcTM7YWZ/U9tmlreyrZWR0Qmmpn0xPk5EpKlVDXQzywAPAncBm4B7zWxTyTkdwEPAB9z9RuAXa9/Uy/UeG+Jz3/4+7nD7Jx+j99hQvT9SRKSpxRmhbwFOuvuL7j4OHAS2l5zzy8Ahd/8BgLu/UttmXq732BB7Dz3N66OTALz8+hh7Dz2tUBeRVIsT6F3AqaLng/ljxd4JrDCzvzazJ8zsX9WqgeXs6xtgdGLqsmOjE1Ps6xuo58eKiDS1lhjnWJljpUXrFuAfAXcCOeDbZvYdd3/+sjcy2wXsArj++uvn3tq80yOjczouIpIGcUbog8DaoudrgNNlzvmau7/h7j8GvgncUvpG7n7A3bvdvbuzs3O+bWZ1R25Ox0VE0iBOoB8FNpjZejNrBXYCh0vO+QrwT8ysxczagHcDz9a2qZfs2bqRXDZz2bFcNsOerRvr9ZEiIk2vasnF3SfN7D6gD8gAj7j7CTPbnX99v7s/a2ZfA54CpoGH3f2ZejV6x+aohL+v7zmGRsa4amkL/2XHu2aOi4ikkbk3Zg13d3e39/f3L/h9tnziL7njhut44BdurkGrRESam5k94e7d5V5L9JWiEN25SHctEhEJIdDbs9oTXUSEEAK9rVV3LRIRIYRAb2/VfUVFRAgh0NuiPdEbNbkrItIsAgj0VqamnbNjk41uiohIQwUR6IBudCEiqZf8QG/PAuhGFyKSeskP9MIIXYEuIikXTqC/obXoIpJuyQ/0do3QRUQggEC/elkLmSWmQBeR1Et8oJsZK9qynFHJRURSLvGBDtDRpqtFRUSCCPSV2nFRRCSMQO9o046LIiJBBPrK9lZNiopI6gUR6B1tUaBrgy4RSbNYgW5m28xswMxOmllPmdffa2avm9nx/J+P176pla1szzIx5bwxPrWYHysi0lSq3iTazDLAg8DPAoPAUTM77O7fKzn1b9395+vQxqo6ijboumpp1S6JiAQpzgh9C3DS3V9093HgILC9vs2am5Xaz0VEJFagdwGnip4P5o+V+mkz+zsz+6qZ3Vjujcxsl5n1m1n/8PDwPJpbXmHHRS1dFJE0ixPoVuZY6ezjk8Bb3f0W4L8DveXeyN0PuHu3u3d3dnbOqaGzKWzQpaWLIpJmcQJ9EFhb9HwNcLr4BHc/6+7n84+PAFkzW1WzVlZRCHSN0EUkzeIE+lFgg5mtN7NWYCdwuPgEM3uzmVn+8Zb8+75a68ZWcnUuyxJDl/+LSKpVXRLi7pNmdh/QB2SAR9z9hJntzr++H/gQ8G/NbBIYBXb6Ii4KzywxrsllOaNAF5EUi7XGL19GOVJybH/R488An6lt0+ZmRXsrr6mGLiIpFsSVohDV0XWjaBFJs7ACXSN0EUmxgAI9qxG6iKRaMIFe2HFRG3SJSFoFE+gdba1cnJxmdEIbdIlIOgUT6N9/9TwAN368j9seeIzeY0MNbpGIyOIKItB7jw3xp09GAe7A0Mgoew89rVAXkVQJItD39Q0wPnV57Xx0Yop9fQMNapGIyOILItBPj4zO6biISIiCCPTVHbk5HRcRCVEQgb5n60Zy2cwVx4dGRjVBKiKpEcT92nZsju63sa9vgKGSMkthgrT4PBGREAUxQocorL/VcwddZcosmiAVkTQIJtALKk2EqvwiIqELLtBnmwjV+nQRCVlwgV5pgrRA5RcRCVUQk6LFZpsgLdD6dBEJUawRupltM7MBMztpZj2znPePzWzKzD5UuybO3WwTpBBtD6B6uoiEpmqgm1kGeBC4C9gE3Gtmmyqc90mie482hdnKL6qni0ho4ozQtwAn3f1Fdx8HDgLby5z3G8CXgVdq2L4F2bG5i/vvuaniSF31dBEJSZxA7wJOFT0fzB+bYWZdwAeB/czCzHaZWb+Z9Q8PD8+1rfNSKL9YhddVTxeRUMQJ9HJZWHpboE8Bv+vus95dwt0PuHu3u3d3dnbGbGJtVFrOqHq6iIQiTqAPAmuLnq8BTpec0w0cNLPvAx8CHjKzHbVoYK2oni4ioYsT6EeBDWa23sxagZ3A4eIT3H29u69z93XAl4B/5+69tW7sQqieLiKhqxro7j4J3Ee0euVZ4IvufsLMdpvZ7no3sJaq1dO1PYCIJFmsC4vc/QhwpORY2QlQd/+1hTervlZ35CpedKTdGUUkqYK79D8ObQ8gIiEK7tL/OLQ9gIiEKJUjdND2ACISntQGeoGWM4pIKFIf6FrOKCKhSH2gg5YzikgYFOhFdLcjEUkyBXoRLWcUkSRL5bLFSrScUUSSTCP0ElrOKCJJpUCvQMsZRSRpFOgVaDmjiCSNAn0WWs4oIkmiQI9ByxlFJAkU6DFoOaOIJIGWLcag5YwikgQaocek5YzJ1XtsiNseeIz1PY/qeyRBixXoZrbNzAbM7KSZ9ZR5fbuZPWVmx82s38xur31Tm4OWMyZL77Eh9h56mqGRURx9jyRsVQPdzDLAg8BdwCbgXjPbVHLaN4Bb3P1W4NeBh2vczqah5YzJsq9vgNGJqcuO6XskoYozQt8CnHT3F919HDgIbC8+wd3Pu7vnn7YTVSCCVW05o+rpzaPS90LfIwlRnEDvAk4VPR/MH7uMmX3QzJ4DHiUapV/BzHblSzL9w8PD82lvU6m0nFH19OZR6Xs021JUkaSKE+jlBqJXjMDd/U/d/QZgB/D75d7I3Q+4e7e7d3d2ds6poc1I9fTmt2frRlpbLv8xz2Uz7Nm6sUEtEqmfOIE+CKwter4GOF3pZHf/JvB2M1u1wLY1PdXTm9+OzV3cs/nSL5RdHTnuv+emmaWoIiGJE+hHgQ1mtt7MWoGdwOHiE8zsHWZm+cc/CbQCr9a6sc1I9fTm17l8KQC3ru3gWz13KMwlWFUD3d0ngfuAPuBZ4IvufsLMdpvZ7vxpvwA8Y2bHiVbE/FLRJGkqqJ7evE6duQDA6PhUlTNFki3WlaLufgQ4UnJsf9HjTwKfrG3TkmXP1o3sPfT0FUvk4FI9HdDosAFOvRb9llTueyMSEl0pWiOqpzevH+RH6Bc0QpfAaS+XGtqxuYsdm7tY3/No2YX4qqfXV++xIfb1DXB6ZJTVHTn2bN3Itne9meFzFwEYHZ9scAtF6kuBXgerO3JlN/Eyg/U9j86EjcovtVO4xL9QVimUuX50Nvo+rF2ZY+i1Udyd/Py9SHBUcqmDSuvTpx3tJ1InlS7xP/DNlwDY+KblTDtcnJxuRPNEFoUCvQ6K6+kGZMqMCFVTr61K5axX3xgHYMOblgNa6SJhU8mlTgr1dIjKLOWopl47lcpcVy3NMDHlXL+yDYj+IV2x2I0TWSQaoS8CrVGvvz1bN7K0zCX+b1vVzpoVOdpaoxKYVrpIyBToi0B7vtTfjs1dfPj29TPPr21v5f57bmLKYe3KNtpao19Gx7QWXQKmQF8EWqO+OG5e0zHz+Be717Jjcxenzlxg7Yq2mX9QNUKXkCnQF4n2fKm/c2MTAFy3fCnffelVXh+d4OzYJGtX5sjNlFy0Fl3CpUBfZKqn18/5i1FY3/kT1/H04Os8/6NzAJeN0FVykZAp0BeZ6un1c24sH+g3vInJaecrx6P/j1ENXSUXCZ8CfZGpnl4/58YmWJZdwrvftpIlBn/2dy8DCnRJDwV6A6ieXh/nL06yfFmW5cuyvKvrGl4fneDqZS1ck8uyrFUlFwmfAr2BdL/L2jo7NsnypdHyxJVt2Zljtz3wGH954oeARugSNgV6A1Wqpw+NjGqCdB7OjU2yfFkLvceGePyFMzPHh0ZG+VjvCTKmQJew6dL/BipsDbCvb+CKy9Z1U4y5Oz82wfJlWfb1DTA+dfkmXKMTUxgquUjYYo3QzWybmQ2Y2Ukz6ynz+r8ws6fyfx43s1tq39QwFerp5SZJNUE6N+fGJrlqaUvFOQhH69AlbFUD3cwyRPcJvQvYBNxrZptKTnsJ+Bl3vxn4feBArRsaukohpAnS+Aoll0pzEC1LjNEJbZ8r4YozQt8CnHT3F919HDgIbC8+wd0fd/fX8k+/A6ypbTPDpwnShSuscik3N5HLZrhu+VLdtUiCFifQu4BTRc8H88cq+TDw1XIvmNkuM+s3s/7h4eH4rUwBTZAuzNS05wO95Yr96Ls6ctx/z02s7shpUlSCFmdStNxy6XK3zMTM3kcU6LeXe93dD5Avx3R3d5d9j7TSBOnCFC77X74s+pEu3o++4MtPDs6cJxKiOCP0QWBt0fM1wOnSk8zsZuBhYLu7v1qb5qWLJkjnr7AxVyHQy8llM7pjkQQtTqAfBTaY2XozawV2AoeLTzCz64FDwL909+dr38x00QTp3F0aoWcrntPWmlHJRYJWteTi7pNmdh/QB2SAR9z9hJntzr++H/g4cC3wUP6O6pPu3l2/Zoet0u3UNEFaWWFjrquWzjJCb81ccSNpkZDEWofu7kfc/Z3u/nZ3/0T+2P58mOPu/8bdV7j7rfk/CvMF0ATp3MUrubSo5CJB05WiTUgTpHNXGKFXL7lM4u7kf5MUCYr2cmlSmiCdm0uBPnvJZdq5YlsAkVAo0JucJkjjiRXo+TKWyi4SKgV6k9MVpPGcvzhBZolVvBsUMHNfUU2MSqgU6E1OE6TxFDbmmq02rrsWSeg0KdrkNEEaT2Fjrtmo5CKh0wg9ATRBWl0U6JVXuIBKLhI+BXqCaIK0snNjEzO3n6tEJRcJnQI9QTRBWlm8kkv0urbQlVAp0BOk3ATpsuwS9mzd2KAWNY9zFyeqB7pKLhI4BXqClO7zDeAOH/3C8dSveDk/NslVVQJdJRcJnQI9YQoTpH/4S7eSMePi5DTOpRUvaQx1d5/bpKgCXQKlQE+ofX0DTPnl9whJ64qXsYlpJqddyxYl9RToCaUVL5ecu5jfabHKKpdsZgnZjHFBNXQJlAI9obTi5ZI4Oy0WLNNdiyRgCvSE0pYAl8TZmKugrVWBLuHSpf8JpS0BLjkf425FBW2tLSq5SLA0Qk8wbQkQuXS3IpVcJN1iBbqZbTOzATM7aWY9ZV6/wcy+bWYXzex3at9MmU3aJ0jnXHKZ0JWiEqaqgW5mGeBB4C5gE3CvmW0qOe0M8JvAH9S8hVJV2idIz12MAvrqGCP06DZ0GqFLmOKM0LcAJ939RXcfBw4C24tPcPdX3P0oMFGHNkoVaZ8gLZRc2pdWvrlFgUouErI4gd4FnCp6Ppg/NmdmtsvM+s2sf3h4eD5vIWUUbwlQKg1XkJ4bm6StNUNLpvqPc1RyUaBLmOIEerlbwHiZY1W5+wF373b37s7Ozvm8hVSQ5gnS8zF2WizIZVVykXDFCfRBYG3R8zXA6fo0RxYqjROk5y5OxFqyCNF+LmMKdAlUnEA/Cmwws/Vm1grsBA7Xt1kyX2mcII2zMVdBW2uGCxNTuM/rl0yRplY10N19ErgP6AOeBb7o7ifMbLeZ7QYwszeb2SDw74H/bGaDZnZ1PRsu5aVxgvTsHEsuU9PO+NR0nVslsvhi/S1w9yPAkZJj+4se/5CoFCMNlsYrSM+PTdDVsSzWubnW6Ed+bHyapS3VV8WIJImuFA1Q2iZIz41Nsnxp/JILwAVdXCQBUqAHLC0TpHHuJ1qgPdElZAr0gFWaCHUIpp5+6IlTjE5M8fD/fSlWn3K6DZ0ETIEesEoTpBDGBUe9x4b4T73PzDyP06c23ShaAqZAD9hsV5BC8uvp+/oGGJu4fLVKtT6p5CIhU6AHrjBBWu5yX0h2PX0+cwQquUjIFOgpEWI9fT4XUc2M0LXKRQKkQE+JEOvpv/Nz77ziWC6bYc/WjRW/pi2/Dn10XBcWSXgU6CkRYj395rUdAHTkshjQ1ZHj/ntumvWiqUslF43QJTy6p2iK7NjcxY7NXazvebTsdplDI6Os73mU1R059mzd2PRXkz5+8scA9H7kNtatao/1NZoUlZBphJ5Cs9WYneSUYB5/4VVWX7OMt17bFvtrWluW0LLEtGxRgqRAT6HZ6ukFoxNT/PYXjjfthOn0tPPtF1/lPe9YhVmlNTxX6j02xNS089Bfv9C0fROZL5VcUqh4A6/TI6Oz3q2kWTf0+t7LZxm5MMFt77g29tf0Hhti76GnZ/rbrH0TmS+N0FOqsD79pQfeX3GitKAZR+uPvxDVz9/z9lWxv2Zf38AVpZYkTgaLVKJAl1glGIhGtB/9wnHW9Tza0HDvPTbEf/368wDc89DjsdtR6YKjkPeKl3SxRt25pbu72/v7+xvy2XKl3mNDZfdQn40RTaJ25LKYwciFibqvkOk9NkTPoacuu+Q/l81UXa4I0QVUs/Wv0J+uhKzykXQysyfcvbvsawp0KVaoMy9kFUi5oL9mAaFf+Mfm9MgoSwymyvzIdnXk+FbPHVXfJ27fat0HkVpZcKCb2Tbg00AGeNjdHyh53fKv3w1cAH7N3Z+c7T0V6M1rPqP1uaoWmIXHr12YmDm32vu99MD7q35uLftW6TeU993QyV89N8zpkdGyfWqWx83e1mZv30LbOt9BwYIC3cwywPPAzwKDRDeNvtfdv1d0zt3AbxAF+ruBT7v7u2d7XwV686vFaH2xxBmhF6tWfhFZDHHLhcVmC/Q4k6JbgJPu/qK7jwMHge0l52wH/qdHvgN0mNlbYrdQmlLpdgHxV3svrmr7t5QTdyJYpJ5qvcoqzjr0LuBU0fNBolF4tXO6gJeLTzKzXcAugOuvv36ubZUGKGwXAJeXK+KUQeopY8a0+7x/bS29mXaj+yPpVcstrOMEermBWenPfpxzcPcDwAGISi4xPluaSLlwL64Xxq13L9R8fk0tp1n/sZJ0mW0rjrmKE+iDwNqi52uA0/M4RwJSHIbFygX9fCY5S9V7SWG1f6xq0QeRUvMpF84mzqRoC9Gk6J3AENGk6C+7+4mic94P3MelSdE/cvcts72vJkXTq1pgJmGJYKU+hL4yQ+1L+CqX/BvcDXyKaNniI+7+CTPbDeDu+/PLFj8DbCNatviv3X3WtFagi4jM3WyBHmtzLnc/AhwpOba/6LEDH1lII0VEZGG0l4uISCAU6CIigVCgi4gEQoEuIhKIhu22aGbDwD/M88tXAT+uYXOSIo39TmOfIZ39TmOfYe79fqu7d5Z7oWGBvhBm1l9p2U7I0tjvNPYZ0tnvNPYZattvlVxERAKhQBcRCURSA/1AoxvQIGnsdxr7DOnsdxr7DDXsdyJr6CIicqWkjtBFRKSEAl1EJBCJC3Qz22ZmA2Z20sx6Gt2eejCztWb2V2b2rJmdMLPfyh9faWZ/YWZ/n//vika3tdbMLGNmx8zsz/PP09DnDjP7kpk9l/+e/3RK+v3R/M/3M2b2eTNbFlq/zewRM3vFzJ4pOlaxj2a2N59tA2a2da6fl6hAz9+w+kHgLmATcK+ZbWpsq+piEvgP7v4TwE8BH8n3swf4hrtvAL6Rfx6a3wKeLXqehj5/Gviau98A3ELU/6D7bWZdwG8C3e7+LqKtuXcSXr8/S7SteLGyfcz/Hd8J3Jj/mofymRdbogKdeDesTjx3f9ndn8w/Pkf0F7yLqK+fy5/2OWBHQxpYJ2a2Bng/8HDR4dD7fDXwT4E/AXD3cXcfIfB+57UAufxNdNqI7nIWVL/d/ZvAmZLDlfq4HTjo7hfd/SXgJFHmxZa0QK90M+pgmdk6YDPwXeBN7v4yRKEPXNfAptXDp4D/CEwXHQu9z28DhoH/kS81PWxm7QTeb3cfAv4A+AHRzeRfd/evE3i/8yr1ccH5lrRAj3Uz6lCY2VXAl4HfdvezjW5PPZnZzwOvuPsTjW7LImsBfhL4Y3ffDLxB8ssMVeXrxtuB9cBqoN3MfqWxrWq4Bedb0gI9NTejNrMsUZj/H3c/lD/8IzN7S/71twCvNKp9dXAb8AEz+z5RKe0OM/vfhN1niH6mB939u/nnXyIK+ND7/c+Al9x92N0ngEPAewi/31C5jwvOt6QF+lFgg5mtN7NWogmEww1uU83l79H6J8Cz7v7fil46DPxq/vGvAl9Z7LbVi7vvdfc17r6O6Pv6mLv/CgH3GcDdfwicMrPCrd/vBL5H4P0mKrX8lJm15X/e7ySaKwq931C5j4eBnWa21MzWAxuA/zend3b3RP0B7gaeB14APtbo9tSpj7cT/ar1FHA8/+du4FqiWfG/z/93ZaPbWqf+vxf48/zj4PsM3Ar057/fvcCKlPT794DngGeA/wUsDa3fwOeJ5ggmiEbgH56tj8DH8tk2ANw118/Tpf8iIoFIWslFREQqUKCLiARCgS4iEggFuohIIBToIiKBUKCLiARCgS4iEoj/D7CrobJDlXAEAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "# learns to repeat simple sequence from random inputs\n",
    "np.random.seed(0)\n",
    "loss_plt=[]\n",
    "# parameters for input data dimension and lstm cell count \n",
    "epochs = 100\n",
    "mem_cell_ct = 100\n",
    "x_dim = 50\n",
    "concat_len = x_dim + mem_cell_ct\n",
    "lstm_param = LstmParam(mem_cell_ct, x_dim) \n",
    "lstm_net = LstmNetwork(lstm_param)\n",
    "y_list = [-0.5,0.2,0.1, -0.5]\n",
    "input_val_arr = [np.random.random(x_dim) for _ in y_list]\n",
    "\n",
    "for cur_iter in range(epochs):\n",
    "    #print(\"cur iter: \", cur_iter)\n",
    "    for ind in range(len(y_list)):\n",
    "        lstm_net.x_list_add(input_val_arr[ind])\n",
    "        #print(\"y_pred[%d] : %f\" % (ind, lstm_net.lstm_node_list[ind].state.h[0]))\n",
    "\n",
    "    loss = lstm_net.y_list_is(y_list, ToyLossLayer)\n",
    "    loss_plt.append(loss)\n",
    "    #print(\"loss: \", loss)\n",
    "    lstm_param.apply_diff(lr=0.1)\n",
    "    lstm_net.x_list_clear()\n",
    "plt.plot(range(0,epochs),loss_plt,'o-',label='loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
